using Pkg
# Install (or confirm) every package the script loads below. This runs unconditionally
# on each execution of the script.
Pkg.add(["Distributed", "Distributions", "LinearAlgebra", "SpecialFunctions", "GSL", "DataFrames", "CSV"])

# Throughout, subnational regions vary along rows and countries along columns.
# When iterating over regions, first iterate over all subnational regions of the first
# country, then move on to the next country. The same ordering is used in tau_mat.


using Distributed

# Start worker processes only when running single-process, so re-running the script
# in the same session reuses the workers that already exist.
nprocs() == 1 && addprocs(9)

# Load the model's dependencies on the master process and on every worker.
@everywhere using Distributed, Distributions, LinearAlgebra, SpecialFunctions, GSL, DataFrames, CSV

# When true, intra_solve fixes every match acceptance probability at `a_io`, and the
# results are written to a file name prefixed with "IO_". Defined on all processes.
@everywhere IO_flag = false

# The output CSV is written relative to this directory.
cd("$(homedir())/Documents/Personal Research/JAERE Extension/Draft/Code/appendix_code")

# Inter-country iceberg trade costs to sweep; intra_solve runs once per value.
inter_trade_costs = range(1.0, stop = 1.2, length = 9)

# Solve the production-network model for one inter-country iceberg trade cost.
#
# The economy has n_grid_cntry = 2 countries, each split into n_grid_subnat = 2
# subnational regions (n_grid_reg = 4 regions). Firm types sit on an
# n_grid_phi x n_grid_del grid of fundamental productivity (phi) and quality (del) in
# every region. The 3-D arrays below are indexed (phi, del, region), with the regions
# of country 1 first. For each intra-regional trade cost in `intra_trade_costs` the
# function does the following:
#   1. solves the network equilibrium with every firm type entering (Solver);
#   2. iterates on firms' entry and subnational location choices;
#   3. appends firm-level outcomes to the returned DataFrame.
#
# Arguments
#   iceberg : inter-country iceberg trade cost; scales every cross-country entry of tau_mat.
#   IO_flag : when true, every match acceptance probability is fixed at `a_io` instead of
#             being derived from the relationship-cost (xi) distribution.
#
# Returns a DataFrame with one row per (phi, del, region) grid point and intra-regional
# trade cost. Its columns are declared in `output_df`.
@everywhere function intra_solve(iceberg, IO_flag)
    output_df = DataFrame(
        phi = Float64[],
        del = Float64[],
        reg = Float64[],
        country = Int64[],
        subnat = Int64[],
        loc_prob = Float64[],
        entry = Int64[],
        Phi_network = Float64[],
        Del_network = Float64[],
        kappa = Float64[],
        dirty_input = Float64[],
        emissions = Float64[],
        ttl_output = Float64[],
        FA = Float64[],
        e_intensity = Float64[],
        inter_trade_cost = Float64[],
        intra_trade_cost = Float64[],
        emissions_tax = Float64[],
        Del_H = Float64[],
        mass_cus = Float64[],
        mass_sup = Float64[],
        int_sold_q = Float64[],
        int_purch_q = Float64[],
        int_sold_val = Float64[],
        int_purch_val = Float64[],
        g_chi = Float64[])

    intra_trade_costs = [1.17]

    for intra in intra_trade_costs
        # ---- Model parameters ----
        t = .439               # emissions tax
        a_LC = 0.02            # scale dividing the coastal-minus-inland profit gap in location choice
        A = 1.0                # productivity shifter multiplying match profits in PhiDelIter
        w = 1.0                # price of dirty input
        eps = 0.034            # scale of emissions generation
        b = 0.000001           # denominator addition for abatement (keeps b + FA > 0)
        max_kap = Inf          # maximum unit price of dirty input

        # TODO (original note): who gets the tax revenue?
        sig = 4.0              # elasticity of substitution
        alp = 1.0              # input suitability (same value with or without IO_flag)
        a_io = 0.1             # fixed acceptance probability used when IO_flag is true
        m_phi = 0.0            # mean of log fundamental productivity
        m_del = 0.0            # mean of log fundamental quality
        v_phi = 1.0            # std of log fundamental productivity
        v_del = 1.0            # std of log fundamental quality
        v_cor = 0.0            # correlation between log phi and log del

        m_psi = 0.216          # matching-cost scale within a country
        m_psi_F = 0.6          # matching-cost scale across countries

        s_psi = 0.0            # elasticity of psi function (currently unused)
        s_xi = 0.9565          # Weibull shape of the relationship-cost shock xi
        nu = 0.0               # relationship stickiness (earlier value: .1641)
        D = 1.0                # size of dirty input
        bet = 0.0              # discount factor
        mu = sig / (sig - 1)   # markup
        entry_cost = 0         # a firm type enters when its inland profit exceeds this

        m_xi = 1.0
        # Weibull scale chosen so that E[xi] = lam_xi * gamma(1 + 1/s_xi) = m_xi.
        lam_xi = m_xi / gamma(1 + 1 / s_xi) # TODO (original note): check against Lim's code

        qt_min = 0.05          # min quantile for phi/del distribution
        qt_max = 0.95          # max quantile for phi/del distribution

        n_grid = 5             # grid size
        tol = 0.001            # numerical tolerance for iterations
        max_iter = 200         # max iterations
        update = 0.9           # damping: weight on the new iterate

        n_grid_phi = n_grid
        n_grid_del = n_grid
        n_grid_subnat = 2
        n_grid_reg = 4
        n_grid_cntry = Int(n_grid_reg / n_grid_subnat)

        # Region-specific productivity and quality shifters (all ones here).
        A_phi = ones(n_grid_phi, n_grid_del, n_grid_reg)
        A_del = ones(n_grid_phi, n_grid_del, n_grid_reg)

        # Matching costs: psi_D between regions of the same country, psi_F across countries.
        psi_D = m_psi .* ones(n_grid_subnat, n_grid_subnat)
        psi_F = m_psi_F .* ones(n_grid_subnat, n_grid_subnat)
        psi_mat = [psi_D psi_F;
                   psi_F psi_D]

        # Iceberg trade costs. Rows are origin regions, columns destination regions
        # (i = inland = subnational region 1, c = coastal = region 2, h/f = home/foreign):
        #          ih                ch               if                cf
        #   ih     1                 intra            intra^2 * inter   intra * inter
        #   ch     intra             1                intra * inter     inter
        #   if     intra^2 * inter   intra * inter    1                 intra
        #   cf     intra * inter     inter            intra             1
        tau_mat_intra = intra .* ones(n_grid_subnat, n_grid_subnat)
        tau_mat_intra[diagind(tau_mat_intra)] .= 1   # shipping within a region is free
        tau_mat_inter = intra .* iceberg .* ones(n_grid_subnat, n_grid_subnat)
        # Diagonal of the cross-country block: inland -> inland pays the intra cost twice,
        # coastal -> coastal pays only the border cost. Hard-coded for n_grid_subnat == 2.
        tau_mat_inter[diagind(tau_mat_inter)] .= [intra^2 * iceberg, iceberg]
        tau_mat = [tau_mat_intra tau_mat_inter;
                   tau_mat_inter tau_mat_intra]

        # Density-weighted sum of D under weights g, skipping NaN products (e.g. Inf * 0
        # at inactive grid points). With renorm == 1, the weights are first rescaled to
        # sum to one over the non-NaN entries of D.
        function matmean(D, g, renorm = 0)
            if renorm == 1
                g = g / sum(g[.!isnan.(D)])
            end
            Dg = D .* g
            return sum(Dg[.!isnan.(Dg)])
        end

        # CDF of the Weibull(s_xi, lam_xi) relationship-cost shock xi (broadcasts over x).
        G_xi(x) = cdf.(Weibull(s_xi, lam_xi), x)

        # Partial expectation E[xi; xi <= A] of the Weibull shock, written with the upper
        # incomplete gamma function.
        function E_xi(A)
            return lam_xi .* (gamma(1 + 1 / s_xi) .-
                              GSL.sf_gamma_inc.(1 .+ (1 ./ s_xi), (A ./ lam_xi) .^ s_xi))
        end

        # Grids of fundamental productivity (varies along dim 1) and quality (dim 2) at
        # evenly spaced quantiles of their log-normal marginals, repeated over regions.
        qt_grid = range(qt_min, stop = qt_max, length = n_grid_phi)
        phi = quantile.(LogNormal(m_phi, v_phi), qt_grid)
        phi = repeat(phi, inner = (1, n_grid_del, n_grid_reg))

        qt_grid = range(qt_min, stop = qt_max, length = n_grid_del)
        del = quantile.(LogNormal(m_del, v_del), qt_grid)
        del = repeat(del', inner = (n_grid_phi, 1, n_grid_reg))

        # Baseline density over firm types: the bivariate log-normal pdf at each grid
        # point, normalized to sum to one over all grid points and regions.
        M = [m_phi, m_del]
        V = [v_phi^2 v_cor*v_phi*v_del; v_cor*v_phi*v_del v_del^2]
        g_chi_fun(x_phi, x_del) = pdf(MvLogNormal(M, V), [x_phi; x_del])
        g_chi_guess = zeros(n_grid_phi, n_grid_del, n_grid_reg)
        for n_phi in 1:n_grid_phi, n_del in 1:n_grid_del, n_reg in 1:n_grid_reg
            phi_pt = phi[n_phi, n_del, n_reg]
            del_pt = del[n_phi, n_del, n_reg]
            if phi_pt == 0 || del_pt == 0
                g_chi_guess[n_phi, n_del, n_reg] = 0
            else
                g_chi_guess[n_phi, n_del, n_reg] = g_chi_fun(phi_pt, del_pt)
            end
        end
        g_chi_guess = g_chi_guess ./ (sum(g_chi_guess))

        # Complete the location probabilities. Within each country, the probability of
        # subnational region 1 is set to 1 minus that of region 2, and the result is
        # multiplied by `entry`. reshape shares memory, so the input array is also
        # modified in place.
        function loc_prob_balance(loc_prob, entry)
            loc_prob = reshape(loc_prob, (n_grid_phi, n_grid_del, n_grid_subnat, n_grid_cntry))
            loc_prob[:, :, 1, :] = 1 .- loc_prob[:, :, 2, :]
            loc_prob = reshape(loc_prob, (n_grid_phi, n_grid_del, n_grid_reg))
            return loc_prob .* entry
        end

        # Trade-cost and matching-cost arrays for a firm in origin region `n_reg`, laid
        # out like the firm grid (phi, del, partner region). The `_C` arrays hold costs of
        # selling to customers in each region (tau_mat[n_reg, d]). The `_S` arrays hold
        # costs of buying from suppliers in each region (tau_mat[d, n_reg]).
        function region_costs(n_reg)
            tau_C = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            tau_S = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            psi_C = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            psi_S = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            for d_reg in 1:n_grid_reg
                tau_C[:, :, d_reg] .= tau_mat[n_reg, d_reg]
                tau_S[:, :, d_reg] .= tau_mat[d_reg, n_reg]
                psi_C[:, :, d_reg] .= psi_mat[n_reg, d_reg]
                psi_S[:, :, d_reg] .= psi_mat[d_reg, n_reg]
            end
            return tau_C, tau_S, psi_C, psi_S
        end

        # Fixed-point update for abatement spending FA given Del_H and Del. This is
        # presumably the abatement first-order condition (confirm against the model
        # derivation). Results are floored at 1e-19 and zeroed where entry == 0.
        function FA_iter(F_iter, Del_H_iter, Del_iter, entry, loc_prob)
            tmp_phi = phi .* entry .* loc_prob
            F_new = 1 ./ w .* ((t .* eps .* (sig .- 1) .* tmp_phi .^ (sig .- 1) .*
                (b .+ F_iter) .^ (sig .- 2) .* (mu - 1) .* Del_H_iter .*
                Del_iter) .^ (1 / sig) .- t .* eps)
            F_new = max.(F_new, 0.0000000000000000001)   # keep FA strictly positive
            F_new[entry .== 0] .= 0
            return F_new
        end

        # One update of the demand shifter Del_H given (Phi, Del) and abatement FA.
        # Returns Dict("Del_H_new" => new Del_H guess (floored at 0.001),
        #              "RC" => relationship costs plus abatement per firm type).
        function DelHIter(Phi_sol, Del_sol, Del_H_iter, FA, g_chi, entry, loc_prob)
            # D_f_C[:, :, :, firm] is the relationship cost the firm type at grid point
            # `firm` = (n_phi, n_del, n_reg) incurs with each potential customer.
            D_f_C = zeros(n_grid_phi, n_grid_del, n_grid_reg, n_grid_phi, n_grid_del, n_grid_reg)

            for n_reg in 1:n_grid_reg
                # Costs depend only on the firm's own region, so build them once per region.
                tau, _, psi, _ = region_costs(n_reg)
                for n_phi in 1:n_grid_phi, n_del in 1:n_grid_del
                    # Profit the firm generates with each potential customer.
                    pi_C = mu^(-sig) * (mu - 1) * (tau ./ alp) .^ (1 - sig) .* Del_H_iter .*
                        Phi_sol[n_phi, n_del, n_reg] .* Del_sol
                    # Largest xi draw for which each customer relationship stays active.
                    xi_max_C = max.((pi_C ./ psi .- bet * nu) ./ (1 - bet * nu), 0)
                    # Acceptance (= matching) probabilities with each customer.
                    m_C = G_xi(xi_max_C)
                    if IO_flag == true
                        m_C .= a_io
                    end
                    # Expected relationship-cost draw per customer.
                    xi_bar_C = E_xi(xi_max_C)
                    D_f_C[:, :, :, n_phi, n_del, n_reg] = (nu * m_C + (1 - nu) * xi_bar_C) .* psi
                end
            end

            # Unit price of the dirty input including the emissions tax net of abatement.
            kap = w .+ ((t * eps) ./ (b .+ FA))
            kap = min.(kap, max_kap)
            kap = kap .* entry .* loc_prob

            # Each firm type's density-weighted relationship cost over all customers,
            # taken from the slice for the firm's own region n_reg.
            D_f = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            for n_phi in 1:n_grid_phi, n_del in 1:n_grid_del, n_reg in 1:n_grid_reg
                D_f[n_phi, n_del, n_reg] = matmean(D_f_C[:, :, :, n_phi, n_del, n_reg], g_chi, 0)
            end

            tmp_phi = entry .* loc_prob .* (phi .* A_phi) .^ (sig - 1)

            # New guess of Del_H. It scales demand so that the aggregate of the
            # Solver `dirty_input` formula equals D. Zero kap at inactive types gives
            # Inf * 0 = NaN, which matmean skips.
            Del_H_new = max.(D / matmean(Del_sol .* (kap .^ (-sig)) .* tmp_phi, g_chi, 0), 0.001)
            # NOTE: if the dirty input becomes region-specific, Del_H should be a vector.
            RC = (D_f + FA) .* entry .* loc_prob

            return Dict{String, Any}("Del_H_new" => Del_H_new, "RC" => RC)
        end

        # One Jacobi update of the network productivity Phi and demand Del at every grid
        # point given Del_H and FA. Returns Dict("Phi" => ..., "Del" => ...).
        function PhiDelIter(Phi_Iter, Del_Iter, Del_H_iter, FA, g_chi, entry, loc_prob)
            Phi_new = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            Del_new = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            con = A * mu^(-sig) * (mu - 1) * alp^(sig - 1) * Del_H_iter

            for n_reg in 1:n_grid_reg
                tau_C, tau_S, psi_C, psi_S = region_costs(n_reg)
                for n_phi in 1:n_grid_phi, n_del in 1:n_grid_del
                    kap = w .+ (eps .* t ./ (b .+ FA[n_phi, n_del, n_reg]))
                    kap = min.(kap, max_kap)

                    # Profit generated with every potential supplier (S) and customer (C).
                    pi_S = con .* tau_S .^ (1 - sig) .* Phi_Iter .* Del_Iter[n_phi, n_del, n_reg]
                    pi_C = con .* tau_C .^ (1 - sig) .* Phi_Iter[n_phi, n_del, n_reg] .* Del_Iter

                    xi_max_S = max.((pi_S ./ psi_S .- bet * nu) ./ (1 .- bet * nu), 0)
                    xi_max_C = max.((pi_C ./ psi_C .- bet * nu) ./ (1 .- bet * nu), 0)

                    # Acceptance (= matching) probabilities with suppliers and customers.
                    m_S = G_xi(xi_max_S)
                    m_C = G_xi(xi_max_C)
                    if IO_flag == true
                        m_S .= a_io
                        m_C .= a_io
                    end

                    active = entry[n_phi, n_del, n_reg] .* loc_prob[n_phi, n_del, n_reg]
                    fundamental_prod = active .*
                        (phi[n_phi, n_del, n_reg] * A_phi[n_phi, n_del, n_reg] ./ kap)^(sig - 1)
                    fundamental_dem = active .*
                        mu^(-sig) * (del[n_phi, n_del, n_reg] * A_del[n_phi, n_del, n_reg])^(sig - 1)

                    Phi_new[n_phi, n_del, n_reg] = fundamental_prod + (alp ./ mu)^(sig - 1) *
                        matmean(Phi_Iter .* m_S .* tau_S .^ (1 - sig), g_chi, 0)
                    Del_new[n_phi, n_del, n_reg] = fundamental_dem + mu^(-sig) * alp^(sig - 1) *
                        matmean(Del_Iter .* m_C .* tau_C .^ (-sig), g_chi, 0)
                end
            end
            return Dict{String, Any}("Phi" => Phi_new, "Del" => Del_new)
        end

        # Solve the network equilibrium for given entry and location probabilities.
        #
        # Three nested fixed points:
        #   - (Phi, Del) given Del_H and FA (PhiDelIter);
        #   - Del_H given (Phi, Del) (DelHIter, damped by `update`);
        #   - abatement FA (FA_iter, damped).
        # The guesses are copied, so the caller's arrays are not modified.
        #
        # Returns a Dict with keys "Phi", "Del", "Del_H", "F", "RC", "kap", "dirty_input",
        # "emissions", "ttl_output", "e_intensity", "var_profit" and "profit".
        function Solver(entry, loc_prob, Del_H_guess, Phi_guess, Del_guess, F_guess)
            # The enclosing scope also has variables with these names (the entry/location
            # loop keeps its current solution in Phi, Del, Del_H, FA, ...). A nested Julia
            # function that assigns to an enclosing local overwrites it, so declare them
            # local here.
            local Phi, Del, Del_H, FA, RC, kap_out
            local dirty_input, emissions, ttl_output, var_profit

            out = Dict()

            inactive = entry .== 0
            F_guess = copy(F_guess)
            Del_guess = copy(Del_guess)
            Phi_guess = copy(Phi_guess)
            F_guess[inactive] .= 0
            Del_guess[inactive] .= 0
            Phi_guess[inactive] .= 0

            # Density of active firm types.
            g_chi = g_chi_guess .* loc_prob .* entry

            F_res = Inf
            F_niter = 0
            while F_res > tol && F_niter <= max_iter
                F_niter += 1
                if F_niter == 1
                    FA = F_guess
                    out["F"] = FA
                else
                    FA = out["F"]
                end
                FA[inactive] .= 0

                # ---- Del_H fixed point (restarted from Del_H_guess on every FA pass) ----
                Del_H = Del_H_guess
                Del_H_res = Inf
                Del_H_niter = 0
                while Del_H_res > tol && Del_H_niter <= max_iter
                    Del_H_niter += 1

                    # Start from the guesses on the first pass, else from the last (Phi, Del).
                    if Del_H_niter == 1
                        Phi = Phi_guess
                        Del = Del_guess
                    else
                        Phi = out["Phi"]
                        Del = out["Del"]
                    end

                    # ---- (Phi, Del) fixed point given Del_H and FA ----
                    PhiDel_res = Inf
                    PhiDel_iter = 0
                    while PhiDel_res > tol && PhiDel_iter <= max_iter
                        PhiDel_iter += 1
                        out_PhiDel = PhiDelIter(Phi, Del, Del_H, FA, g_chi, entry, loc_prob)
                        Phi_new = out_PhiDel["Phi"]
                        Phi_new[inactive] .= 0
                        Del_new = out_PhiDel["Del"]
                        Del_new[inactive] .= 0

                        PhiDel_res = max(maximum(abs.(Phi - Phi_new)),
                                         maximum(abs.(Del - Del_new)))
                        if PhiDel_res > tol
                            Phi = Phi_new
                            Del = Del_new
                        end
                    end
                    # Record the latest iterate even when the loop stopped at max_iter, so the
                    # next Del_H pass never reads a missing or stale entry.
                    out["Phi"] = Phi
                    out["Del"] = Del
                    if PhiDel_res > tol
                        println("warning: Phi/Del iteration failed to converge")
                        sleep(1)
                    end

                    Del_H_dict = DelHIter(Phi, Del, Del_H, FA, g_chi, entry, loc_prob)
                    Del_H_new = Del_H_dict["Del_H_new"]
                    RC = Del_H_dict["RC"]

                    Del_H_res = abs.(Del_H - Del_H_new)
                    if Del_H_res > tol
                        Del_H = update .* Del_H_new .+ (1 .- update) .* Del_H
                    end
                end

                if Del_H_res > tol
                    println("warning: Del_H iteration failed to converge")
                    sleep(1)
                end
                # Always record the latest Del_H and RC. On convergence these are the
                # converged values.
                out["Del_H"] = Del_H
                out["RC"] = RC

                # ---- Damped abatement update given the solved network ----
                F_new = FA_iter(FA, Del_H, Del, entry, loc_prob)
                F_new[inactive] .= 0
                F_res = maximum(abs.(FA - F_new))
                out["F"] = update .* F_new .+ (1 - update) .* FA
            end

            if F_res > tol && F_niter >= max_iter
                println("F DID NOT CONVERGE")
            end

            Phi = out["Phi"]
            Del = out["Del"]
            Del_H = out["Del_H"]
            FA = out["F"]
            RC = out["RC"]

            kap_out = min.(w .+ ((t * eps) ./ (b .+ FA)), max_kap)
            out["kap"] = kap_out

            dirty_input = Del_H .* Del .* kap_out .^ (-sig) .* phi .^ (sig - 1)
            out["dirty_input"] = dirty_input
            emissions = eps .* dirty_input ./ (b .+ FA)
            out["emissions"] = emissions
            ttl_output = Del_H .* Del .* Phi .^ (sig / (sig - 1))
            out["ttl_output"] = ttl_output
            out["e_intensity"] = emissions ./ ttl_output
            var_profit = (mu - 1) .* Del_H .* Del .* Phi
            out["var_profit"] = var_profit
            out["profit"] = var_profit .- RC .- FA
            return out
        end

        # Per-firm-type network statistics at a solution. Reports the mass of suppliers
        # and customers, plus intermediate quantities and values sold and purchased.
        # `g_chi` is the density of active firm types.
        function intermediate_inputs(Del_H, Phi, Del, g_chi)
            mass_cus = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            mass_sup = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            intermediates_sold_q = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            intermediates_purchased_q = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            intermediates_sold_val = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            intermediates_purchased_val = zeros(n_grid_phi, n_grid_del, n_grid_reg)

            con = mu^(-sig) * (mu - 1) * alp^(sig - 1) * Del_H
            for n_reg in 1:n_grid_reg
                tau_C, tau_S, psi_C, psi_S = region_costs(n_reg)
                for n_phi in 1:n_grid_phi, n_del in 1:n_grid_del
                    pi_S = con .* tau_S .^ (1 - sig) .* Phi .* Del[n_phi, n_del, n_reg]
                    pi_C = con .* tau_C .^ (1 - sig) .* Phi[n_phi, n_del, n_reg] .* Del

                    xi_max_S = max.((pi_S ./ psi_S .- bet * nu) ./ (1 .- bet * nu), 0)
                    xi_max_C = max.((pi_C ./ psi_C .- bet * nu) ./ (1 .- bet * nu), 0)

                    m_S = G_xi(xi_max_S)
                    m_C = G_xi(xi_max_C)
                    if IO_flag == true
                        m_S .= a_io
                        m_C .= a_io
                    end

                    mass_sup[n_phi, n_del, n_reg] = matmean(m_S, g_chi)
                    mass_cus[n_phi, n_del, n_reg] = matmean(m_C, g_chi)

                    intermediates_sold_q[n_phi, n_del, n_reg] = Del_H * mu^(-sig) *
                        Phi[n_phi, n_del, n_reg]^(sig / (sig - 1)) *
                        matmean(m_C .* tau_C .^ (-sig) .* Del, g_chi)
                    intermediates_purchased_q[n_phi, n_del, n_reg] = Del_H * mu^(-sig) *
                        Del[n_phi, n_del, n_reg] *
                        matmean(m_S .* tau_S .^ (-sig) .* Phi .^ (sig / (sig - 1)), g_chi)

                    intermediates_sold_val[n_phi, n_del, n_reg] = Del_H * mu^(1 - sig) *
                        Phi[n_phi, n_del, n_reg] * matmean(m_C .* tau_C .^ (1 - sig) .* Del, g_chi)
                    intermediates_purchased_val[n_phi, n_del, n_reg] = Del_H * mu^(1 - sig) *
                        Del[n_phi, n_del, n_reg] * matmean(m_S .* tau_S .^ (1 - sig) .* Phi, g_chi)
                end
            end
            return Dict{String, Any}(
                "mass_sup" => mass_sup,
                "mass_cus" => mass_cus,
                "intermediates_sold_q" => intermediates_sold_q,
                "intermediates_purchased_q" => intermediates_purchased_q,
                "intermediates_sold_val" => intermediates_sold_val,
                "intermediates_purchased_val" => intermediates_purchased_val)
        end

        # Region index (1..n_regions) of every grid point, flattened in column-major order.
        function region_id(n_regions)
            reg = zeros(n_grid_phi, n_grid_del, n_regions)
            for reg_id in 1:n_regions
                reg[:, :, reg_id] .= reg_id
            end
            return reg[:]
        end

        # ---- Initial conditions ----
        # Every firm type enters. Subnational region 2 ("coastal") starts with location
        # probability 0.75, and loc_prob_balance gives region 1 ("inland") the remaining 0.25.
        entry_guess = ones(n_grid_phi, n_grid_del, n_grid_reg)
        loc_prob_guess = 0.75 .* ones(n_grid_phi, n_grid_del, n_grid_reg)
        loc_prob_guess = loc_prob_balance(loc_prob_guess, entry_guess)

        F_guess = 0.0000001 .* ones(n_grid_phi, n_grid_del, n_grid_reg)

        # Del_H for an empty network (no matches) as the initial guess.
        prod_1 = (phi .* A_phi) .^ (sig - 1)
        prod_2 = (del .* A_del) .^ (sig - 1)
        phidel_bar = mean(prod_1 .* prod_2)
        Del_H_guess = (mu^sig) * D / phidel_bar

        # Network-free Phi and Del as initial guesses.
        Phi_guess = (phi .* A_phi) .^ (sig - 1)
        Del_guess = mu^(-sig) .* (del .* A_del) .^ (sig - 1)

        # Solve under the initial entry/location guess, then rank firm types by their
        # profit in subnational region 1 of country 1. The profit matrix must not be
        # called `A`: PhiDelIter captures the productivity shifter `A`, and rebinding it
        # here would silently replace that scalar in every later solve.
        guess = Solver(entry_guess, loc_prob_guess, Del_H_guess, Phi_guess, Del_guess, F_guess)
        profit_orig = reshape(guess["profit"], (n_grid_phi, n_grid_del, n_grid_subnat, n_grid_cntry))
        profit_reg1 = profit_orig[:, :, 1, 1]
        # (phi, del) indices from most to least profitable; sortperm keeps tied types distinct.
        rankA = CartesianIndices(profit_reg1)[sortperm(vec(profit_reg1), rev = true)]

        # Firm types whose region-1 profit falls below the entry cost exit everywhere.
        entry_guess[(profit_reg1 .< entry_cost), :] .= 0
        loc_prob_guess[entry_guess .== 0] .= 0

        # Current solution, used as the starting guess for every solve below.
        Del_H = guess["Del_H"]
        Del = guess["Del"]
        Phi = guess["Phi"]
        FA = guess["F"]

        entry = copy(entry_guess)
        loc_prob = copy(loc_prob_guess)
        iter = 0
        res = Inf

        # ---- Iterate on entry and location choices ----
        while res > tol && iter < max_iter
            iter += 1
            entry_new = copy(entry)
            loc_prob_new = copy(loc_prob)

            # Evaluate each firm type's unilateral deviation, most profitable first.
            # Each deviation is solved against the location probabilities updated so
            # far in this sweep.
            for idx in rankA
                # Columns [1, 3] are subnational region 1 ("inland") of each country and
                # [2, 4] are region 2 ("coastal"). Hard-coded for n_grid_subnat == 2.
                loc_prob_inland = copy(loc_prob_new)
                loc_prob_inland[idx, [1, 3]] .= 1
                loc_prob_inland[idx, [2, 4]] .= 0
                loc_prob_coastal = copy(loc_prob_new)
                loc_prob_coastal[idx, [2, 4]] .= 1
                loc_prob_coastal[idx, [1, 3]] .= 0

                inland_profit = Solver(entry, loc_prob_inland, Del_H, Phi, Del, FA)["profit"]
                coastal_profit = Solver(entry, loc_prob_coastal, Del_H, Phi, Del, FA)["profit"]

                # Entry is decided on inland profit in country 1 (adjust if countries are
                # asymmetric). Entrants pick the coast with probability G_xi of the scaled
                # coastal-minus-inland profit gap.
                if inland_profit[idx, 1] > entry_cost
                    entry_new[idx, :] .= 1
                    firm_coast_loc_prob = G_xi((coastal_profit[idx, 2] - inland_profit[idx, 1]) / a_LC)
                    loc_prob_new[idx, [2, 4]] .= firm_coast_loc_prob
                    loc_prob_new[idx, [1, 3]] .= 1 - firm_coast_loc_prob
                else
                    entry_new[idx, :] .= 0
                    loc_prob_new[idx, :] .= 0
                end
            end

            entry_res = maximum(abs.(entry_new - entry))
            loc_res = maximum(abs.(loc_prob_new - loc_prob))
            res = max(entry_res, loc_res)

            if res > tol
                entry = copy(entry_new)
                loc_prob = copy(loc_prob_new)
            end

            println("Iteration $iter residual: $res")

            # Re-solve with the new entry and location probabilities to refresh the guesses.
            results = Solver(entry, loc_prob, Del_H, Phi, Del, FA)
            Del_H = results["Del_H"]
            Del = results["Del"]
            Phi = results["Phi"]
            FA = results["F"]
        end

        if res > tol
            println("warning: entry/location iteration failed to converge")
        end

        # ---- Final solve and output ----
        results = Solver(entry, loc_prob, Del_H, Phi, Del, FA)
        Del_H = results["Del_H"]
        Del = results["Del"]
        Phi = results["Phi"]
        FA = results["F"]
        kap_out = results["kap"]
        dirty_input = results["dirty_input"]
        emissions = results["emissions"]
        ttl_output = results["ttl_output"]
        e_intensity = results["e_intensity"]

        intermediate_info = intermediate_inputs(Del_H, Phi, Del, g_chi_guess .* entry .* loc_prob)

        country = repeat(1:n_grid_cntry, inner = n_grid_phi * n_grid_del * n_grid_subnat)
        subnat = repeat(1:n_grid_subnat, inner = n_grid_phi * n_grid_del, outer = n_grid_cntry)
        n_rows = n_grid_phi * n_grid_del * n_grid_reg

        f_output = DataFrame(
            phi = phi[:],
            del = del[:],
            loc_prob = loc_prob[:],
            entry = entry[:],
            country = country,
            subnat = subnat,
            reg = region_id(n_grid_reg),
            Phi_network = Phi[:],
            Del_network = Del[:],
            kappa = kap_out[:],
            dirty_input = dirty_input[:],
            emissions = emissions[:],
            ttl_output = ttl_output[:],
            FA = FA[:],
            e_intensity = e_intensity[:],
            inter_trade_cost = iceberg .* ones(n_rows),
            intra_trade_cost = intra .* ones(n_rows),
            emissions_tax = t .* ones(n_rows),
            Del_H = Del_H .* ones(n_rows),
            mass_cus = intermediate_info["mass_cus"][:],
            mass_sup = intermediate_info["mass_sup"][:],
            int_sold_q = intermediate_info["intermediates_sold_q"][:],
            int_purch_q = intermediate_info["intermediates_purchased_q"][:],
            int_sold_val = intermediate_info["intermediates_sold_val"][:],
            int_purch_val = intermediate_info["intermediates_purchased_val"][:],
            g_chi = g_chi_guess[:] .* entry[:] .* loc_prob[:]
        )
        append!(output_df, f_output)
    end
    return output_df
end


# Solve the model once per inter-country trade cost across the worker pool. The `vcat`
# reducer stacks the per-value DataFrames into a single table.
firm_output = @distributed vcat for iceberg in inter_trade_costs
    print(iceberg)
    intra_solve(iceberg, IO_flag)
end;

# Runs with fixed acceptance probabilities (IO_flag) go to a separately named file.
file_name = "main_model.csv"
IO_flag == true && (file_name = string("IO_", file_name))

using CSV
CSV.write(file_name, firm_output)

# 1.125